#!/usr/bin/env python

'''Tests the driver API for making connections and exercises the networking code'''

from __future__ import print_function

import asyncio, datetime, os, random, socket, sys, tempfile, time

sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir, "common"))
import driver, rdb_unittest, utils

# -- import the rethinkdb driver

# `r` is the driver module handle used by every connect/query call in this file
r = utils.import_python_driver()

# -- use asyncio subdriver

# select the asyncio loop type; the tests below drive r.connect()/run() with `yield from`
r.set_loop_type("asyncio")

# -- get settings

# path of the server executable: first command-line argument when given, otherwise located by utils
rethinkdb_exe = sys.argv[1] if len(sys.argv) > 1 else utils.find_rethinkdb_executable()

# -- shared server

# server process shared between test classes; created lazily (and replaced when unhealthy) by TestWithConnection.setUp
sharedServer       = None
# optional externally managed server, read from the environment (both are None when the variables are unset)
externalServerHost = os.environ.get('RDB_SERVER_HOST', None)
externalServerPort = int(os.environ['RDB_DRIVER_PORT']) if 'RDB_DRIVER_PORT' in os.environ else None

# These flags gate the conditionally defined tests in TestConnectionDefaults:
# the port flag is True only when RDB_DRIVER_PORT equals the driver default,
# the host flag is True when RDB_SERVER_HOST is unset or 'localhost'
USE_DEFAULT_PORT   = externalServerPort == r.DEFAULT_PORT
USE_DEFAULT_HOST   = externalServerHost in ('localhost', None)

# --

def just_do(coroutine, *args, **kwargs):
    '''Run a coroutine function to completion on the current event loop.

    `coroutine` is called with `*args` and `**kwargs`; the object it returns is
    handed to `run_until_complete()`. The coroutine's result is returned
    (previously it was silently dropped, although `run()` returns this value).
    Any exception raised by the coroutine propagates to the caller.
    '''
    return asyncio.get_event_loop().run_until_complete(coroutine(*args, **kwargs))

# == Test Base Classes

class TestCaseAsyncioCompatible(rdb_unittest.TestCaseCompatible):
    '''Test case base whose setUp, test methods and tearDown are coroutines.

    `run()` replaces the standard synchronous runner: it drives `arun()` on the
    event loop, and `arun()` reports the outcome to the result object.
    '''

    def run(self, result=None):
        '''Synchronous entry point used by the test runner; delegates to `arun()`.'''
        # can't use standard TestCase run here because async.
        return just_do(self.arun, result)

    @asyncio.coroutine
    def setUp(self):
        '''Coroutine fixture hook; does nothing by default.'''
        return None

    @asyncio.coroutine
    def tearDown(self):
        '''Coroutine cleanup hook; does nothing by default.'''
        return None

    @asyncio.coroutine
    def arun(self, result=None):
        '''Run setUp, the selected test method and tearDown, recording the outcome.

        A failing setUp is recorded as an error and skips both the test and
        tearDown. KeyboardInterrupt is always re-raised. `stopTest` is called
        on every path.
        '''
        outcome = self.defaultTestResult() if result is None else result
        outcome.startTest(self)
        test_coroutine = getattr(self, self._testMethodName)

        try:
            # -- fixture setup
            try:
                yield from self.setUp()
            except KeyboardInterrupt:
                raise
            except:
                outcome.addError(self, sys.exc_info())
                return

            # -- the test itself
            try:
                yield from test_coroutine()
            except self.failureException:
                outcome.addFailure(self, sys.exc_info())
                passed = False
            except KeyboardInterrupt:
                raise
            except:
                outcome.addError(self, sys.exc_info())
                passed = False
            else:
                passed = True

            # -- cleanup runs whether or not the test passed
            try:
                yield from self.tearDown()
            except KeyboardInterrupt:
                raise
            except:
                outcome.addError(self, sys.exc_info())
                passed = False

            if passed:
                outcome.addSuccess(self)
        finally:
            outcome.stopTest(self)


class TestWithConnection(TestCaseAsyncioCompatible):
    '''Base class for tests that need a server and an open connection.

    `setUp` picks one of three server sources:
      * a private server for this class, when `private_server`, `use_tls` or
        `admin_pass` is set;
      * the module-wide `sharedServer`, when no external server is configured;
      * the external server named by RDB_SERVER_HOST / RDB_DRIVER_PORT.
    It then makes sure `self.conn` is usable and that the `test` database
    exists and contains no tables. `server`, `conn`, `host` and `port` are
    stored on the class so they are reused by the following tests.
    '''

    server  = None
    conn    = None
    host    = None
    port    = None

    admin_pass     = None
    private_server = False
    use_tls        = False

    @staticmethod
    def _ssl_option(server):
        '''Return the `ssl` argument for r.connect() matching `server` (None without a TLS cert).'''
        if server and server.tlsCertPath:
            return {'ca_certs': server.tlsCertPath}
        return None

    @staticmethod
    def _stop_server(server, label):
        '''Stop `server`, reporting (not raising) any error; `label` names it in the message.'''
        try:
            server.stop()
        except Exception as err:
            sys.stderr.write('Got error while shutting down %s server: %s\n' % (label, str(err)))

    @asyncio.coroutine
    def _probe_server(self, server):
        '''Raise unless `server` passes check() and answers a trivial query.'''
        server.check()
        probe = yield from r.connect(host=server.host, port=server.driver_port, ssl=self._ssl_option(server), password=self.admin_pass or '')
        try:
            yield from r.expr(1).run(probe)
        finally:
            # the probe connection exists only for this check; close it rather than leak one per test
            yield from probe.close(noreply_wait=False)

    @asyncio.coroutine
    def _discard_connection(self):
        '''Best-effort close of the class connection, then forget it.'''
        try:
            yield from self.__class__.conn.close()
        except Exception:
            pass  # conn may be None or already broken; either way it is unusable
        self.__class__.conn = None

    @asyncio.coroutine
    def setUp(self):
        '''Provide a healthy server, an open `self.conn` and an empty `test` database.

        Raises:
            Exception: when the admin password cannot be set on a new private server.
        '''
        global sharedServer

        if self.private_server or self.use_tls or self.admin_pass:
            # a private server must never be the shared one
            if self.server and self.server == sharedServer:
                self.__class__.server = None

            # - validate running server
            if self.server:
                try:
                    yield from self._probe_server(self.server)
                except Exception:
                    # ToDo: figure out how to blame the last test
                    self._stop_server(self.server, 'private')
                    yield from self._discard_connection()
                    self.__class__.server = None

            # - setup private server
            if self.server is None:
                sys.stdout.write('new private server... ')
                sys.stdout.flush()
                self.__class__.server = driver.Process(executable_path=rethinkdb_exe, console_output=self.__class__.__name__ + '_server_console.txt', wait_until_ready=True, tls=self.use_tls)

                if self.admin_pass is not None:
                    # a fresh server has no admin password yet, so connect without one
                    admin_conn = yield from r.connect(host=self.server.host, port=self.server.driver_port, ssl=self._ssl_option(self.server))
                    try:
                        result = yield from r.db('rethinkdb').table('users').get('admin').update({'password':self.admin_pass}).run(admin_conn)
                    finally:
                        yield from admin_conn.close(noreply_wait=False)
                    if result != {'skipped': 0, 'deleted': 0, 'unchanged': 0, 'errors': 0, 'replaced': 1, 'inserted': 0}:
                        # don't leave a half-configured server process behind
                        self._stop_server(self.server, 'private')
                        self.__class__.server = None
                        raise Exception('Unable to set admin password, got: %s' % str(result))

            self.__class__.host   = self.server.host
            self.__class__.port   = self.server.driver_port

        elif not any([externalServerHost, externalServerPort]):
            # a leftover private server is not usable as the shared one
            if self.server and self.server != sharedServer:
                self.__class__.server = None

            # - validate running server
            if sharedServer is not None:
                try:
                    yield from self._probe_server(sharedServer)
                except Exception:
                    # ToDo: figure out how to blame the last test
                    self._stop_server(sharedServer, 'shared')
                    yield from self._discard_connection()
                    sharedServer = None
                    self.__class__.server = None

            # - setup shared server
            if sharedServer is None:
                sys.stdout.write('new shared server... ')
                sys.stdout.flush()
                sharedServer = driver.Process(executable_path=rethinkdb_exe, console_output=False, wait_until_ready=True, tls=self.use_tls)

            self.__class__.server = sharedServer
            self.__class__.host   = self.server.host
            self.__class__.port   = self.server.driver_port

        else:
            # - use external server

            self.__class__.server = None
            self.__class__.host   = externalServerHost or 'localhost'
            self.__class__.port   = externalServerPort or r.DEFAULT_PORT

        # - establish a connection

        if self.conn:
            # reuse the previous test's connection only if it still answers
            try:
                yield from r.expr(1).run(self.conn)
            except Exception:
                self.__class__.conn = None
        if not self.conn:
            self.__class__.conn = yield from r.connect(host=self.host, port=self.port, ssl=self._ssl_option(self.server), password=self.admin_pass or '')

        # - ensure `test` db exists and is empty

        yield from r.expr(['test']).set_difference(r.db_list()).for_each(r.db_create(r.row)).run(self.conn)
        yield from r.db('test').table_list().for_each(r.db('test').table_drop(r.row)).run(self.conn)

# == Test Classes

class TestNoConnection(TestCaseAsyncioCompatible):
    '''Connection failures that do not need a running server.'''

    @asyncio.coroutine
    def test_connect_port(self):
        '''Connecting to the default host on a port from utils.get_avalible_port() must fail.'''
        target_port = utils.get_avalible_port()
        expected = "Could not connect to localhost:%d." % target_port
        with self.assertRaisesRegexp(r.ReqlDriverError, expected):
            yield from r.connect(port=target_port)

    @asyncio.coroutine
    def test_connect_timeout(self):
        '''Test that we get a ReQL error if we connect to a non-responsive port'''
        # a listening socket that never accepts: the handshake cannot complete
        listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        listener.bind(('localhost', 0))
        listener.listen(0)

        host, port = listener.getsockname()

        try:
            expected = "Connection interrupted during handshake with %s:%d. Error: Operation timed out." % (host, port)
            with self.assertRaisesRegexp(r.ReqlDriverError, expected):
                yield from r.connect(host=host, port=port, timeout=2)
        finally:
            listener.close()

    @asyncio.coroutine
    def test_connect_host(self):
        '''Connecting to 0.0.0.0 on a port from utils.get_avalible_port() must fail.'''
        target_port = utils.get_avalible_port()
        expected = "Could not connect to 0.0.0.0:%d." % target_port
        with self.assertRaisesRegexp(r.ReqlDriverError, expected):
            yield from r.connect(host="0.0.0.0", port=target_port)

    @asyncio.coroutine
    def test_empty_run(self):
        '''run() without a connection, and without a prior repl(), reports a driver error.'''
        expected = ("RqlQuery.run must be given"
                    " a connection to run on.")
        with self.assertRaisesRegexp(r.ReqlDriverError, expected):
            yield from r.expr(1).run()

class TestConnectionDefaults(TestWithConnection):
    '''Checks r.connect() when host and/or port are left to the driver defaults.

    The tests are only defined when the configured server is actually
    reachable through those defaults (see USE_DEFAULT_PORT / USE_DEFAULT_HOST).
    '''

    @asyncio.coroutine
    def _check_reconnect(self, **connect_args):
        '''Connect with `connect_args`, reconnect once, and always close the connection.'''
        conn = yield from r.connect(**connect_args)
        try:
            yield from conn.reconnect()
        finally:
            # these connections were previously left open after each test
            yield from conn.close(noreply_wait=False)

    if USE_DEFAULT_PORT:
        @asyncio.coroutine
        def test_connect_default_port(self):
            yield from self._check_reconnect(host=self.host)

        @asyncio.coroutine
        def test_connect_default_port_wrong_auth(self):
            with self.assertRaisesRegexp(r.ReqlAuthError, "Could not connect to %s:%d: Wrong password" % (self.host, self.port)):
                yield from r.connect(host=self.host, password="hunter2")

    if USE_DEFAULT_HOST:
        @asyncio.coroutine
        def test_connect_default_host(self):
            yield from self._check_reconnect(port=self.port)

        @asyncio.coroutine
        def test_connect_default_host_wrong_auth(self):
            with self.assertRaisesRegexp(r.ReqlAuthError, "Could not connect to %s:%d: Wrong password" % (self.host, self.port)):
                yield from r.connect(port=self.port, password="hunter2")

    if USE_DEFAULT_PORT and USE_DEFAULT_HOST:
        @asyncio.coroutine
        def test_connect_default_host_port(self):
            yield from self._check_reconnect()

        @asyncio.coroutine
        def test_connect_wrong_auth(self):
            with self.assertRaisesRegexp(r.ReqlAuthError, "Could not connect to %s:%d: Wrong password" % (self.host, self.port)):
                yield from r.connect(password="hunter2")


class TestAuthConnection(TestWithConnection):
    '''Authentication checks against a private server whose admin password is `admin_pass`.'''

    incorrectAuthMessage = "Could not connect to %s:%d: Wrong password"
    admin_pass = 'hunter2'

    @asyncio.coroutine
    def test_connect_no_auth(self):
        '''Connecting without credentials is rejected.'''
        with self.assertRaisesRegexp(r.ReqlAuthError, self.incorrectAuthMessage % (self.host, self.port)):
            yield from r.connect(host=self.host, port=self.port)

    @asyncio.coroutine
    def test_connect_wrong_auth(self):
        '''An empty key, a different key and a key with a correct prefix are all rejected.'''
        for bad_key in ("", "hunter3", "hunter22"):
            with self.assertRaisesRegexp(r.ReqlAuthError, self.incorrectAuthMessage % (self.host, self.port)):
                yield from r.connect(host=self.host, port=self.port, auth_key=bad_key)

    @asyncio.coroutine
    def test_connect_correct_auth(self):
        '''The correct key is accepted and the connection can reconnect.'''
        # pass the host explicitly, like the other tests here, instead of relying on the driver default
        conn = yield from r.connect(host=self.host, port=self.port, auth_key="hunter2")
        try:
            yield from conn.reconnect()
        finally:
            # don't leave the extra connection open after the test
            yield from conn.close(noreply_wait=False)


class TestConnection(TestWithConnection):
    '''Connection lifecycle, noreply waiting, default-db selection and repl behaviour.'''

    @staticmethod
    def _tls_options(server):
        '''Return the `ssl` argument for r.connect() matching `server` (None without a TLS cert).'''
        if server and server.tlsCertPath:
            return {'ca_certs':server.tlsCertPath}
        return None

    @asyncio.coroutine
    def test_client_port_and_address(self):
        '''Client port/address are reported while open and are None once closed.'''
        for accessor in (self.conn.client_port, self.conn.client_address):
            self.assertIsNotNone(accessor())

        yield from self.conn.close()

        for accessor in (self.conn.client_port, self.conn.client_address):
            self.assertIsNone(accessor())

    @asyncio.coroutine
    def test_connect_close_reconnect(self):
        '''Closing twice is harmless and reconnect() makes the connection usable again.'''
        yield from r.expr(1).run(self.conn)
        for _ in range(2):
            yield from self.conn.close()
        yield from self.conn.reconnect()
        yield from r.expr(1).run(self.conn)

    @asyncio.coroutine
    def test_connect_close_expr(self):
        '''Running a query on a closed connection raises a driver error.'''
        yield from r.expr(1).run(self.conn)
        yield from self.conn.close()
        with self.assertRaisesRegexp(r.ReqlDriverError, "Connection is closed."):
            yield from r.expr(1).run(self.conn)

    @asyncio.coroutine
    def test_noreply_wait_waits(self):
        '''noreply_wait() returns only after the 0.5 s noreply query has finished.'''
        started = time.time()
        yield from r.js('while(true);', timeout=0.5).run(self.conn, noreply=True)
        yield from self.conn.noreply_wait()
        self.assertGreaterEqual(time.time() - started, 0.5)

    @asyncio.coroutine
    def test_close_waits_by_default(self):
        '''close() waits for the outstanding noreply query by default.'''
        started = time.time()
        yield from r.js('while(true);', timeout=0.5).run(self.conn, noreply=True)
        yield from self.conn.close()
        self.assertGreaterEqual(time.time() - started, 0.5)

    @asyncio.coroutine
    def test_reconnect_waits_by_default(self):
        '''reconnect() waits for the outstanding noreply query by default.'''
        started = time.time()
        yield from r.js('while(true);', timeout=0.5).run(self.conn, noreply=True)
        yield from self.conn.reconnect()
        self.assertGreaterEqual(time.time() - started, 0.5)

    @asyncio.coroutine
    def test_close_does_not_wait_if_requested(self):
        '''close(noreply_wait=False) returns before the noreply query finishes.'''
        started = time.time()
        yield from r.js('while(true);', timeout=0.5).run(self.conn, noreply=True)
        yield from self.conn.close(noreply_wait=False)
        self.assertLess(time.time() - started, 0.5)

    @asyncio.coroutine
    def test_reconnect_does_not_wait_if_requested(self):
        '''reconnect(noreply_wait=False) returns before the noreply query finishes.'''
        started = time.time()
        yield from r.js('while(true);', timeout=0.5).run(self.conn, noreply=True)
        yield from self.conn.reconnect(noreply_wait=False)
        self.assertLess(time.time() - started, 0.5)

    @asyncio.coroutine
    def test_db(self):
        '''The default db can be chosen via use(), the connect() `db` argument or a run() option.'''
        connect_args = {
            'host': self.host,
            'port': self.port,
            'ssl': self._tls_options(self.server),
            'password': self.admin_pass or '',
        }
        db_conn = yield from r.connect(**connect_args)

        yield from r.db('test').table_create('t1').run(db_conn)

        # make sure `db2` exists and holds a fresh table `t2`
        yield from r.expr(['db2']).set_difference(r.db_list()).for_each(r.db_create(r.row)).run(db_conn)
        yield from r.expr(['t2']).set_intersection(r.db('db2').table_list()).for_each(r.db('db2').table_drop(r.row)).run(db_conn)
        yield from r.db('db2').table_create('t2').run(db_conn)

        # Default db should be 'test' so this will work
        yield from r.table('t1').run(db_conn)

        # Use a new database
        db_conn.use('db2')
        yield from r.table('t2').run(db_conn)
        with self.assertRaisesRegexp(r.ReqlRuntimeError, "Table `db2.t1` does not exist."):
            yield from r.table('t1').run(db_conn)

        # ... and switch back again
        db_conn.use('test')
        yield from r.table('t1').run(db_conn)
        with self.assertRaisesRegexp(r.ReqlRuntimeError, "Table `test.t2` does not exist."):
            yield from r.table('t2').run(db_conn)

        yield from db_conn.close()

        # Test setting the db in connect
        db_conn = yield from r.connect(db='db2', **connect_args)
        yield from r.table('t2').run(db_conn)

        with self.assertRaisesRegexp(r.ReqlRuntimeError, "Table `db2.t1` does not exist."):
            yield from r.table('t1').run(db_conn)

        yield from db_conn.close()

        # Test setting the db as a `run` option
        yield from r.table('t2').run(self.conn, db='db2')

    @asyncio.coroutine
    def test_outdated_read(self):
        '''The server accepts read_mode='outdated' on both table() and run().'''
        yield from r.db('test').table_create('t1').run(self.conn)

        # Use outdated is an option that can be passed to db.table or `run`
        # We're just testing here if the server actually accepts the option.
        outdated_on_table = r.table('t1', read_mode='outdated')
        yield from outdated_on_table.run(self.conn)
        yield from r.table('t1').run(self.conn, read_mode='outdated')

    @asyncio.coroutine
    def test_repl(self):
        '''repl() installs a connection as the default for run(); closing it breaks run().'''
        try:
            # Calling .repl() should set this connection as global state to be used when `run` is not otherwise passed a connection.
            repl_conn = yield from r.connect(host=self.host, port=self.port, ssl=self._tls_options(self.server), password=self.admin_pass or '')
            repl_conn.repl()
            yield from r.expr(1).run()

            repl_conn.repl() # is idempotent
            yield from r.expr(1).run()

            yield from repl_conn.close()
            with self.assertRaisesRegexp(r.ReqlDriverError, "Connection is closed."):
                yield from r.expr(1).run()
        finally:
            # never leak the global repl connection into other tests
            r.Repl.clear()

    @asyncio.coroutine
    def test_port_conversion(self):
        '''A non-numeric port is rejected with a driver error.'''
        expected = "Could not convert port 'abc' to an integer."
        with self.assertRaisesRegexp(r.ReqlDriverError, expected):
            yield from r.connect(port='abc', host=self.host)

    @asyncio.coroutine
    def test_shutdown(self):
        '''Queries fail with "Connection is closed." after the server has been stopped.'''
        # use our own server for simplicity
        with driver.Process(executable_path=rethinkdb_exe, console_output=False, wait_until_ready=True, tls=self.use_tls) as server:
            own_conn = yield from r.connect(host=server.host, port=server.driver_port, ssl=self._tls_options(server), password=self.admin_pass or '')
            yield from r.expr(1).run(own_conn)

            server.stop()
            yield from asyncio.sleep(0.2)

            with self.assertRaisesRegexp(r.ReqlDriverError, "Connection is closed."):
                yield from r.expr(1).run(own_conn)

            yield from own_conn.close(noreply_wait=False)


class TestTLSConnection(TestConnection):
    '''Re-runs the inherited TestConnection tests with `use_tls` enabled.

    TestWithConnection.setUp reads `use_tls`: a true value selects the
    private-server branch and starts that server with `tls=True`.
    '''
    use_tls = True


class TestBatching(TestWithConnection):
    '''Cursor batching checks using a small `max_batch_rows`.'''

    @asyncio.coroutine
    def test_get_intersecting_batching(self):
        '''Test that get_intersecting() batching works properly'''

        yield from r.db('test').table_create('t1').run(self.conn)
        t1 = r.db('test').table('t1')

        yield from t1.index_create('geo', geo=True).run(self.conn)
        yield from t1.index_wait('geo').run(self.conn)

        batch_size = 3
        point_count = 500
        poly_count = 500
        get_tries = 10

        # Insert a couple of random points, so we get a well distributed range of
        # secondary keys. Also insert a couple of large-ish polygons, so we can
        # test filtering of duplicates on the server.
        # The seed is printed so a failing run can be reproduced.
        rseed = random.getrandbits(64)
        random.seed(rseed)
        print("Random seed: " + str(rseed) + ' ...', end=' ')
        sys.stdout.flush()

        yield from t1.insert(
            [{'geo': r.point(random.uniform(-180.0, 180.0), random.uniform(-90.0, 90.0))} for _ in range(point_count)]
        ).run(self.conn)

        yield from t1.insert(
            [{'geo': r.circle([random.uniform(-180.0, 180.0), random.uniform(-90.0, 90.0)], 1000000)} for _ in range(poly_count)]
        ).run(self.conn)

        # Check that the results are actually lazy at least some of the time
        # While the test is randomized, chances are extremely high to get a lazy result at least once.
        seen_lazy = False

        for _ in range(get_tries):
            query_circle = r.circle([random.uniform(-180.0, 180.0), random.uniform(-90.0, 90.0)], 8000000)
            # the unindexed filter gives the reference result for the indexed query
            reference = yield from t1.filter(r.row['geo'].intersects(query_circle)).coerce_to("ARRAY").run(self.conn)
            cursor = yield from t1.get_intersecting(query_circle, index='geo').run(self.conn, max_batch_rows=batch_size)
            if cursor.error is None:
                # no error recorded after the first batch, i.e. more batches are still pending
                seen_lazy = True

            # every row must appear exactly once in the reference, and nothing may be left over
            while len(reference) > 0:
                row = yield from cursor.next()
                self.assertEqual(reference.count(row), 1)
                reference.remove(row)
            with self.assertRaises(r.ReqlCursorEmpty):
                yield from cursor.next()

        self.assertTrue(seen_lazy)

    # previously undecorated: it only ran because `yield from` also accepts a plain generator
    @asyncio.coroutine
    def test_batching(self):
        '''Test the cursor API when there is exactly mod batch size elements in the result stream'''

        yield from r.db('test').table_create('t1').run(self.conn)
        t1 = r.table('t1')

        batch_size = 3
        count = 500

        ids = set(range(0, count))

        yield from t1.insert([{'id':i} for i in ids]).run(self.conn)
        cursor = yield from t1.run(self.conn, max_batch_rows=batch_size)

        # read all but the last row, ticking each id off the expected set
        for _ in range(0, count - 1):
            row = yield from cursor.next()
            self.assertIn(row['id'], ids)
            ids.remove(row['id'])

        # the final row must be the single id still outstanding, then the cursor is exhausted
        self.assertEqual((yield from cursor.next())['id'], ids.pop())
        with self.assertRaises(r.ReqlCursorEmpty):
            yield from cursor.next()


class TestGroupWithTimeKey(TestWithConnection):
    '''Grouping by a time field returns native datetime objects as group keys.'''

    @asyncio.coroutine
    def runTest(self):
        yield from r.db('test').table_create('times').run(self.conn)

        # (row id, ReQL time term, matching Python datetime) for each sample timestamp
        samples = [
            (
                row_id,
                r.epoch_time(epoch).in_timezone('+00:00'),
                datetime.datetime.fromtimestamp(epoch, r.ast.RqlTzinfo('+00:00')),
            )
            for row_id, epoch in enumerate((1375115782.24, 1375147296.68))
        ]

        for row_id, reql_time, _ in samples:
            res = yield from r.table('times').insert({'id': row_id, 'time': reql_time}).run(self.conn)
            self.assertEqual(res['inserted'], 1)

        # each timestamp is unique, so every group holds exactly its own row
        expected_groups = dict(
            (py_time, [{'id':row_id, 'time':py_time}]) for row_id, _, py_time in samples
        )

        groups = yield from r.table('times').group('time').coerce_to('array').run(self.conn)
        self.assertEqual(groups, expected_groups)


class TestSuccessAtomFeed(TestWithConnection):
    '''A single-document changefeed with include_initial delivers exactly one initial item.'''

    @asyncio.coroutine
    def runTest(self):
        table = r.db('test').table('success_atom_feed')
        yield from r.db('test').table_create('success_atom_feed').run(self.conn)

        for doc in ({'id': 0, 'a': 16}, {'id': 1, 'a': 31}):
            res = yield from table.insert(doc).run(self.conn)
            self.assertEqual(res['inserted'], 1)

        yield from table.index_create('a', lambda x: x['a']).run(self.conn)
        yield from table.index_wait('a').run(self.conn)

        # follow only the document with id 0, asking for its current value first
        feed = yield from table.get(0).changes(include_initial=True).run(self.conn)
        self.assertTrue(feed.error is None)
        self.assertEqual(len(feed.items), 1)

class TestCursor(TestWithConnection):
    '''Cursor iteration, closing, wait/timeout handling and error propagation.'''

    # Used in wait tests
    num_cursors = 3

    @asyncio.coroutine
    def test_type(self):
        cursor = yield from r.range().run(self.conn)
        self.assertTrue(isinstance(cursor, r.Cursor))

    @asyncio.coroutine
    def test_cursor_after_connection_close(self):
        '''Reading a cursor whose connection was closed raises "Connection is closed."'''
        cursor = yield from r.range().run(self.conn)
        yield from self.conn.close()

        @asyncio.coroutine
        def read_cursor(cursor):
            while (yield from cursor.fetch_next()):
                yield from cursor.next()
                # Close after the first row so this loop over an unbounded range
                # still ends if the expected error never shows up.
                # NOTE(review): close() is `yield from`-ed elsewhere in this class,
                # so it is treated as a coroutine here too; the bare call that was
                # here before never actually ran it.
                yield from cursor.close()

        with self.assertRaisesRegexp(r.ReqlRuntimeError, "Connection is closed."):
            yield from read_cursor(cursor)

    @asyncio.coroutine
    def test_cursor_after_cursor_close(self):
        '''Rows already received remain readable after the cursor is closed.'''
        cursor = yield from r.range().run(self.conn)
        yield from cursor.close()
        count = 0
        while (yield from cursor.fetch_next()):
            yield from cursor.next()
            count += 1
        self.assertNotEqual(count, 0, "Did not get any cursor results")

    @asyncio.coroutine
    def test_cursor_close_in_each(self):
        '''Closing a cursor in the middle of iteration ends the iteration.'''
        cursor = yield from r.range().run(self.conn)
        count = 0

        while (yield from cursor.fetch_next()):
            yield from cursor.next()
            count += 1
            if count == 2:
                yield from cursor.close()

        self.assertTrue(count >= 2, "Did not get enough cursor results")

    @asyncio.coroutine
    def test_cursor_success(self):
        '''A bounded cursor yields exactly as many rows as requested.'''
        range_size = 10000
        cursor = yield from r.range().limit(range_size).run(self.conn)
        count = 0
        while (yield from cursor.fetch_next()):
            yield from cursor.next()
            count += 1
        self.assertEqual(count, range_size,
             "Expected %d results on the cursor, but got %d" % (range_size, count))

    @asyncio.coroutine
    def test_cursor_double_each(self):
        '''Iterating an exhausted cursor a second time yields nothing.'''
        range_size = 10000
        cursor = yield from r.range().limit(range_size).run(self.conn)
        count = 0

        while (yield from cursor.fetch_next()):
            yield from cursor.next()
            count += 1
        self.assertEqual(count, range_size,
             "Expected %d results on the cursor, but got %d" % (range_size, count))

        while (yield from cursor.fetch_next()):
            yield from cursor.next()
            count += 1
        self.assertEqual(count, range_size,
             "Expected no results on the second iteration of the cursor, but got %d" % (count - range_size))

    @asyncio.coroutine
    def do_wait_test(self, wait_time):
        '''Read from `num_cursors` slow cursors using `wait_time` as the `wait` argument.

        `wait_time=None` calls next() without a `wait` argument (driver default).
        Returns a tuple `(rows_read, timeouts_seen)` summed over all cursors.
        '''
        cursors = [ ]
        cursor_counts = [ ]
        cursor_timeouts = [ ]
        for i in range(self.num_cursors):
            # the js map busy-waits ~2ms per row so reads can outpace the server
            cur = yield from r.range().map(r.js("(function (row) {" +
                                                "end = new Date(new Date().getTime() + 2);" +
                                                "while (new Date() < end) { }" +
                                                "return row;" +
                                                "})")).run(self.conn, max_batch_rows=500)
            cursors.append(cur)
            cursor_counts.append(0)
            cursor_timeouts.append(0)

        @asyncio.coroutine
        def get_next(cursor_index):
            try:
                if wait_time is None: # Special case to use the default
                    yield from cursors[cursor_index].next()
                else:
                    yield from cursors[cursor_index].next(wait=wait_time)
                cursor_counts[cursor_index] += 1
            except r.ReqlTimeoutError:
                cursor_timeouts[cursor_index] += 1

        # We need to get ahead of pre-fetching for this to get the error we want
        while sum(cursor_counts) < 4000:
            for cursor_index in range(self.num_cursors):
                for read_count in range(random.randint(0, 10)):
                    yield from get_next(cursor_index)

        # Actually run each close; the previous list comprehension only created
        # close() objects without driving them, leaving the cursors open.
        for cursor in cursors:
            yield from cursor.close()

        return (sum(cursor_counts), sum(cursor_timeouts))

    @asyncio.coroutine
    def test_false_wait(self):
        reads, timeouts = yield from self.do_wait_test(False)
        self.assertNotEqual(timeouts, 0, "Did not get timeouts using zero (false) wait.")

    @asyncio.coroutine
    def test_zero_wait(self):
        reads, timeouts = yield from self.do_wait_test(0)
        self.assertNotEqual(timeouts, 0, "Did not get timeouts using zero wait.")

    @asyncio.coroutine
    def test_short_wait(self):
        reads, timeouts = yield from self.do_wait_test(0.0001)
        self.assertNotEqual(timeouts, 0, "Did not get timeouts using short (100 microsecond) wait.")

    @asyncio.coroutine
    def test_long_wait(self):
        reads, timeouts = yield from self.do_wait_test(10)
        self.assertEqual(timeouts, 0, "Got timeouts using long (10 second) wait.")

    @asyncio.coroutine
    def test_infinite_wait(self):
        reads, timeouts = yield from self.do_wait_test(True)
        self.assertEqual(timeouts, 0, "Got timeouts using infinite wait.")

    @asyncio.coroutine
    def test_default_wait(self):
        reads, timeouts = yield from self.do_wait_test(None)
        self.assertEqual(timeouts, 0, "Got timeouts using default (infinite) wait.")

    # This test relies on internals of the asyncio cursor implementation
    # (its `items` buffer and `new_response` future)
    @asyncio.coroutine
    def test_rate_limit(self):
        # Get the first batch
        cursor = yield from r.range().run(self.conn)
        cursor_initial_size = len(cursor.items)

        # Wait for the second (pre-fetched) batch to arrive
        yield from cursor.new_response
        cursor_new_size = len(cursor.items)

        self.assertLess(cursor_initial_size, cursor_new_size)

        # Wait and observe that no third batch arrives
        # (shield keeps the timeout from cancelling the cursor's own future)
        with self.assertRaises(asyncio.TimeoutError):
            yield from asyncio.wait_for(asyncio.shield(cursor.new_response), 2)
        self.assertEqual(cursor_new_size, len(cursor.items))

    # Test that an error on a cursor (such as from closing the connection)
    # properly wakes up waiters immediately
    @asyncio.coroutine
    def test_cursor_error(self):
        cursor = yield from r.range() \
            .map(lambda row:
                r.branch(row <= 5000, # High enough for multiple batches
                         row,
                         r.js('while(true){ }'))) \
            .run(self.conn)

        @asyncio.coroutine
        def read_cursor(cursor, hanging):
            # drain rows until a read times out, signal that, then block on an untimed read
            try:
                while True:
                    yield from cursor.next(wait=1)
            except r.ReqlTimeoutError:
                pass
            hanging.set_result(True)
            yield from cursor.next()

        @asyncio.coroutine
        def read_wrapper(cursor, done, hanging):
            try:
                with self.assertRaisesRegexp(r.ReqlRuntimeError, 'Connection is closed.'):
                    yield from read_cursor(cursor, hanging)
                done.set_result(None)
            except Exception as ex:
                # unblock the test body whichever future it is currently waiting on
                if not hanging.done():
                    hanging.set_exception(ex)
                done.set_exception(ex)

        cursor_hanging = asyncio.Future()
        done = asyncio.Future()
        # `asyncio.async(...)` cannot be parsed on Python 3.7+ (`async` is a keyword);
        # ensure_future is its replacement
        asyncio.ensure_future(read_wrapper(cursor, done, cursor_hanging))

        # Wait for the cursor to hit the hang point before we close and cause an error
        yield from cursor_hanging
        yield from self.conn.close()
        yield from done

class TestChangefeeds(TestWithConnection):
    '''Several changefeeds and writers sharing one connection concurrently.'''

    @asyncio.coroutine
    def setUp(self):
        '''Standard connection setup, plus empty tables `a` and `b` in the `test` db.'''
        yield from super(TestChangefeeds, self).setUp()
        yield from r.expr(['a', 'b']).for_each(r.db('test').table_create(r.row)).run(self.conn)

    @asyncio.coroutine
    def table_a_even_writer(self):
        '''Insert ids 0, 2, ..., 18 into table `a`.'''
        for i in range(10):
            yield from r.db('test').table("a").insert({"id": i * 2}).run(self.conn)

    @asyncio.coroutine
    def table_a_odd_writer(self):
        '''Insert ids 1, 3, ..., 19 into table `a`.'''
        for i in range(10):
            yield from r.db('test').table("a").insert({"id": i * 2 + 1}).run(self.conn)

    @asyncio.coroutine
    def table_b_writer(self):
        '''Insert ids 0..9 into table `b`.'''
        for i in range(10):
            yield from r.db('test').table("b").insert({"id": i}).run(self.conn)

    @asyncio.coroutine
    def cfeed_noticer(self, table, ready, done, needed_values):
        '''Follow `table`'s changefeed until every id in `needed_values` has been seen.

        `ready` is resolved once the feed is open and `done` once all ids were
        removed from `needed_values` (which is mutated in place). Any failure is
        delivered through the futures instead of being raised, so a waiter is
        never left hanging: `ready` gets it if the feed could not be opened, and
        `done` always gets it.
        '''
        try:
            # opening the feed is inside the try: a failure here used to leave `ready` pending forever
            feed = yield from r.db('test').table(table).changes(squash=False).run(self.conn)
            ready.set_result(None)
            while needed_values and (yield from feed.fetch_next()):
                item = yield from feed.next()
                self.assertIsNone(item['old_val'])
                self.assertIn(item['new_val']['id'], needed_values)
                needed_values.remove(item['new_val']['id'])
            done.set_result(None)
        except Exception as ex:
            if not ready.done():
                ready.set_exception(ex)
            done.set_exception(ex)

    @asyncio.coroutine
    def test_multiple_changefeeds(self):
        '''Two feeds see every insert made by three concurrent writers.'''
        feeds_ready = { }
        feeds_done = { }
        needed_values = { 'a': set(range(20)), 'b': set(range(10)) }
        for n in ('a', 'b'):
            feeds_ready[n] = asyncio.Future()
            feeds_done[n] = asyncio.Future()
            # `asyncio.async(...)` cannot be parsed on Python 3.7+; ensure_future replaces it
            asyncio.ensure_future(self.cfeed_noticer(n, feeds_ready[n], feeds_done[n], needed_values[n]))

        # gather (unlike wait) re-raises the first failure, so errors in the
        # noticers or writers fail the test instead of being dropped
        yield from asyncio.gather(*feeds_ready.values())
        yield from asyncio.gather(self.table_a_even_writer(),
                                  self.table_a_odd_writer(),
                                  self.table_b_writer())
        yield from asyncio.gather(*feeds_done.values())
        self.assertTrue(all([len(x) == 0 for x in needed_values.values()]))

# When executed as a script, hand control to the rdb_unittest runner
if __name__ == '__main__':
    rdb_unittest.main()
